# Global knitr chunk defaults: echo the code, suppress warnings and messages,
# and centre every figure.
knitr::opts_chunk$set(
  echo = TRUE,
  warning = FALSE,
  message = FALSE,
  fig.align = "center"
)
# Evaluate all chunks from the project directory so the relative "./analysis"
# paths used below resolve.
knitr::opts_knit$set(
  root.dir = "/mnt/raid5/Personal/tangchao/project/Nanopore/BarcodeDecomplex"
)
library(data.table)
library(ggplot2)
library(caret)
library(multiROC)
library(pROC)
library(parallel)
library(cowplot)
library(patchwork)

# 2 Y

# Evaluate every saved classifier on its test set. For each barcode count the
# loop stores the prediction table, the ROC and PR curve data, and three AUC
# summaries (pROC multiclass AUC plus the multiROC macro and micro AUC).
barcode_counts <- c(2, 4, 6, 8, 10, 12, 14, 16, 46)

# One slot per classifier, filled in the order of `barcode_counts`. This yields
# the same unnamed, position-ordered list as indexing by barcode count and then
# dropping the empty slots, without building a sparse list first.
Performs <- vector("list", length(barcode_counts))
for (k in seq_along(barcode_counts)) {
  i <- barcode_counts[k]
  # NOTE(review): `TestData` is not created in this loop; presumably it is
  # provided by this .RData file -- confirm.
  load(paste0("./analysis/11.Classifiers/01.ModelingData/Barcode_", i, ".RData"))
  Fit1 <- readRDS(paste0("./analysis/11.Classifiers/02.Classifier/Barcode_", i, ".Rds"))

  # Predict the class probabilities once; both tables below derive from them.
  class_prob <- data.frame(predict(Fit1, TestData, type = "prob"))

  # Replace "." with "-" in the probability column names before handing them
  # to pROC (presumably so they match the levels of TestData$Class -- confirm).
  ROC_Tab <- class_prob
  colnames(ROC_Tab) <- gsub("\\.", "-", colnames(ROC_Tab))
  mcroc <- pROC::multiclass.roc(TestData$Class, ROC_Tab)

  # multiROC input: "<class>_pred_RF" probability columns next to
  # "<class>_true" indicator columns.
  rf_pred <- class_prob
  colnames(rf_pred) <- paste0(colnames(rf_pred), "_pred_RF")

  true_label <- dummies::dummy(TestData$Class, sep = " ")
  true_label <- data.frame(true_label)
  colnames(true_label) <- gsub("Class.", "", colnames(true_label))
  colnames(true_label) <- paste0(colnames(true_label), "_true")

  final_df <- cbind(true_label, rf_pred)

  roc_res <- multi_roc(final_df, force_diag = TRUE)
  pr_res <- multi_pr(final_df, force_diag = TRUE)

  plot_roc_df <- as.data.table(plot_roc_data(roc_res))
  plot_pr_df <- as.data.table(plot_pr_data(pr_res))

  # Extend the probability table with the predicted class, the largest value
  # among the columns whose name contains "RTA" (PP), and the true class.
  ROC_Tab$pred <- predict(Fit1, newdata = TestData)
  ROC_Tab$PP <- apply(ROC_Tab[, grepl("RTA", colnames(ROC_Tab))], 1, max)
  ROC_Tab$true <- TestData$Class
  setDT(ROC_Tab)

  Performs[[k]] <- list(
    Pred = ROC_Tab,
    ROC = plot_roc_df,
    PR = plot_pr_df,
    AUC = data.table(
      pROC = auc(mcroc),
      Macro = roc_res$AUC$RF$macro,
      Micro = roc_res$AUC$RF$micro
    )
  )
  # Drop the per-classifier temporaries before the next iteration.
  rm(list = c("ROC_Tab", "class_prob", "plot_roc_df", "plot_pr_df", "mcroc", "roc_res"))
}

# Barcode 24

# Evaluate the 24-barcode classifier the same way as the other classifiers,
# add its results to `Performs`, and save the complete list.
# NOTE(review): `Fit1` and `TestData` are presumably provided by these two
# .RData files -- confirm.
load(file = "./analysis/04.RandomForest/01.ClassifierTraining/BIN_100_24barcodes_Classifier_V2.RData")
load("./analysis/04.RandomForest/01.ClassifierTraining/RData/BIN_100_24barcodes_TestData_V2.RData")

# Observed class, per-class probabilities and predicted class in one table.
ROC_Tab <- data.frame(obs = TestData$Class,
                      predict(Fit1, TestData, type = "prob"),
                      pred = predict(Fit1, newdata = TestData))

colnames(ROC_Tab) <- gsub("\\.", "-", colnames(ROC_Tab))

# Hand only the probability columns to pROC: `obs` and `pred` hold class
# labels, not decision values, so they are left out of the predictor.
label_cols <- colnames(ROC_Tab) %in% c("obs", "pred")
mcroc <- pROC::multiclass.roc(TestData$Class, ROC_Tab[, !label_cols])

# multiROC input: "<class>_pred_RF" probability columns next to
# "<class>_true" indicator columns.
rf_pred <- data.frame(ROC_Tab[, grepl("RTA", colnames(ROC_Tab))])
colnames(rf_pred) <- paste0(colnames(rf_pred), "_pred_RF")

true_label <- dummies::dummy(TestData$Class, sep = " ")
true_label <- data.frame(true_label)
colnames(true_label) <- gsub("Class.", "", colnames(true_label))
colnames(true_label) <- paste0(colnames(true_label), "_true")

final_df <- cbind(true_label, rf_pred)

roc_res <- multi_roc(final_df, force_diag = TRUE)
pr_res <- multi_pr(final_df, force_diag = TRUE)

plot_roc_df <- as.data.table(plot_roc_data(roc_res))
plot_pr_df <- as.data.table(plot_pr_data(pr_res))

# PP is the largest value among the columns whose name contains "RTA".
ROC_Tab$PP <- apply(ROC_Tab[, grepl("RTA", colnames(ROC_Tab))], 1, max)
setDT(ROC_Tab)
setnames(ROC_Tab, "obs", "true")

# Append instead of writing to a fixed position: slot 6 is only free when
# exactly five classifiers were collected before this one. With any other
# number, assigning to slot 6 would replace or misplace an existing entry.
Performs[[length(Performs) + 1L]] <- list(
  Pred = ROC_Tab,
  ROC = plot_roc_df,
  PR = plot_pr_df,
  AUC = data.table(
    pROC = auc(mcroc),
    Macro = roc_res$AUC$RF$macro,
    Micro = roc_res$AUC$RF$micro
  )
)
saveRDS(Performs, file = "./analysis/11.Classifiers/03.Performance/00.RData/Performs.Rds")

# AUC

# Micro-averaged AUC per classifier, and the accuracy/recovery trade-off that
# results from discarding predictions whose top probability (PP) is below a
# cutoff.
Performs <- readRDS("./analysis/11.Classifiers/03.Performance/00.RData/Performs.Rds")

# Barcode count of each entry of `Performs`, in list order. The labels are
# purely positional, so a length mismatch would mislabel results; fail early.
barcode_labels <- c(2, 4, 6, 8, 10, 24)
if (length(Performs) != length(barcode_labels)) {
  stop("`Performs` holds ", length(Performs), " classifiers but ",
       length(barcode_labels), " barcode labels are defined.", call. = FALSE)
}

AUCs <- data.table(
  Barcodes = barcode_labels,
  AUC = vapply(Performs, function(x) x$AUC$Micro, numeric(1))
)
ggplot(AUCs, aes(x = Barcodes, y = AUC)) +
  geom_path()

# Cutoffs 0, 0.01, ..., 1. For each one:
#   Accuracy = share of correct predictions among reads with PP >= cutoff,
#   Recovery = share of all reads with PP >= cutoff.
cutoff_grid <- 0:100 / 100
Cutoff_List <- lapply(Performs, function(x) {
  res <- do.call(rbind, lapply(cutoff_grid, function(p) {
    x$Pred[PP >= p, .(Cutoff = p, Accuracy = mean(pred == true))]
  }))
  res$Recovery <- vapply(cutoff_grid, function(p) x$Pred[, mean(PP >= p)], numeric(1))
  res
})
Cutoff_List <- data.table(
  Barcodes = rep(barcode_labels, vapply(Cutoff_List, nrow, integer(1))),
  do.call(rbind, Cutoff_List)
)

# One curve per classifier; the points mark the 0.3 and 0.5 cutoffs.
trade_off_plot <- ggplot(Cutoff_List, aes(x = Recovery * 100, y = Accuracy * 100, color = as.factor(Barcodes))) +
  geom_path() +
  scale_x_reverse() +
  theme_bw(base_size = 15) +
  scale_color_brewer(palette = "Dark2") +
  guides(color = guide_legend(title = "Barcodes")) +
  labs(x = "Recovery (%)", y = "Accuracy (%)") +
  geom_point(data = Cutoff_List[Cutoff %in% c(0.3, 0.5)],
             mapping = aes(x = Recovery * 100, y = Accuracy * 100, shape = as.factor(Cutoff)), size = 2) +
  guides(shape = guide_legend(title = "Cutoff")) +
  scale_shape_manual(values = c(16, 3))
trade_off_plot

# Pass the plot explicitly instead of relying on ggplot2's "last plot" state.
ggsave("./analysis/11.Classifiers/03.Performance/Trade_off.pdf", plot = trade_off_plot, width = 6, height = 5)
# Per-classifier and combined ROC, precision-recall and cutoff plots.
#
# The plot construction is shared by the single-classifier PDFs (no title) and
# the combined grids (titled "<n> barcodes"), so it lives in small helpers.

# Label each classifier with its barcode count. The labels are positional;
# a length mismatch would silently produce NA names (and "..._NA.pdf" files),
# so it is rejected up front.
barcode_labels <- c(2, 4, 6, 8, 10, 24)
if (length(Performs) != length(barcode_labels)) {
  stop("`Performs` holds ", length(Performs), " classifiers but ",
       length(barcode_labels), " barcode labels are defined.", call. = FALSE)
}
names(Performs) <- barcode_labels

# Add a title to plot `p` only when one is supplied.
add_title <- function(p, title) {
  if (is.null(title)) p else p + labs(title = title)
}

# Accuracy and recovery of a prediction table for cutoffs 0, 0.01, ..., 1.
# `pred_tab` is a data.table with columns PP, pred and true. Accuracy is the
# share of correct predictions among rows with PP >= cutoff; Recovery is the
# share of all rows with PP >= cutoff.
cutoff_table <- function(pred_tab) {
  cutoff_grid <- 0:100 / 100
  tab <- do.call(rbind, lapply(cutoff_grid, function(p) {
    pred_tab[PP >= p, .(Cutoff = p, Accuracy = mean(pred == true))]
  }))
  tab$Recovery <- vapply(cutoff_grid, function(p) pred_tab[, mean(PP >= p)], numeric(1))
  tab
}

# Micro-averaged ROC curve of one classifier, annotated with its micro AUC.
micro_roc_plot <- function(res, title = NULL) {
  p <- ggplot(res$ROC[Group == "Micro"], aes(x = 1 - Specificity, y = Sensitivity)) +
    geom_path(size = 1, color = "#E41A1C") +
    annotate(geom = "text", x = 0.7, y = 0.1, size = 5,
             label = paste0("AUC = ", round(res$AUC$Micro, 3))) +
    theme_bw(base_size = 15)
  add_title(p, title)
}

# Micro-averaged precision-recall curve of one classifier.
micro_pr_plot <- function(res, title = NULL) {
  p <- ggplot(res$PR[Group == "Micro"], aes(x = Recall, y = Precision)) +
    geom_path(size = 1, color = "#377EB8") +
    ylim(c(0, 1)) +
    theme_bw(base_size = 15)
  add_title(p, title)
}

# Accuracy (left axis) and recovery (right axis) against the PP cutoff, drawn
# as two aligned plots laid on top of each other.
cutoff_plot <- function(res, title = NULL) {
  Cutoffs <- cutoff_table(res$Pred)

  # Accuracy layer: its x axis is moved to the top and blanked, so only the
  # left y axis remains visible.
  g1 <- ggplot(Cutoffs, aes(x = Cutoff)) +
    geom_path(aes(y = Accuracy * 100), size = 1, color = "#1B9E77") +
    labs(y = "Accuracy (%)") +
    scale_x_continuous(position = "top") +
    theme_half_open(font_size = 15) +
    theme(axis.text.y.left = element_text(color = "#1B9E77"),
          axis.title.y.left = element_text(color = "#1B9E77"),
          axis.ticks.y = element_line(color = "#1B9E77"),
          axis.title.x = element_blank(),
          axis.text.x = element_blank(),
          axis.ticks.x = element_blank())
  g1 <- add_title(g1, title)

  # Recovery layer: carries the grid lines, the x axis and the right y axis.
  g2 <- ggplot(Cutoffs, aes(x = Cutoff)) +
    geom_path(aes(y = Recovery * 100), size = 1, color = "#D95F02") +
    scale_y_continuous(position = "right") +
    labs(y = "Recovery (%)") +
    theme_half_open(font_size = 15) +
    theme(axis.text.y = element_text(color = "#D95F02"),
          axis.title.y = element_text(color = "#D95F02"),
          panel.grid.major = element_line(color = "grey90"),
          panel.grid.minor = element_line(color = "grey95"),
          axis.ticks.y.right = element_line(color = "#D95F02"))

  # NOTE(review): `align` and `axis` look like cowplot::align_plots()
  # arguments -- confirm that align_patches() honours them.
  aligned_plots <- align_patches(plotlist = list(g1, g2), align="hv", axis="tblr")
  ggdraw(aligned_plots[[2]]) + draw_plot(aligned_plots[[1]])
}

# One untitled PDF per classifier and plot type.
for (j in seq_along(Performs)) {
  print(j)
  res <- Performs[[j]]
  p1 <- micro_roc_plot(res)
  p2 <- micro_pr_plot(res)
  p7_2 <- cutoff_plot(res)

  ggsave(filename = paste0("./analysis/11.Classifiers/03.Performance/AUC_barcodes_", names(Performs)[j], ".pdf"),
         plot = p1, width = 3, height = 3)
  ggsave(filename = paste0("./analysis/11.Classifiers/03.Performance/PR_barcodes_", names(Performs)[j], ".pdf"),
         plot = p2, width = 3, height = 3)
  ggsave(filename = paste0("./analysis/11.Classifiers/03.Performance/Cutoff_barcodes_", names(Performs)[j], ".pdf"),
         plot = p7_2, width = 3.64, height = 3)
}

# Combined two-row grids with one titled panel per classifier.
panel_titles <- paste0(names(Performs), " barcodes")

p1s <- lapply(seq_along(Performs), function(j) {
  micro_roc_plot(Performs[[j]], title = panel_titles[j])
})
p1p <- cowplot::plot_grid(plotlist = p1s, nrow = 2)
ggsave("./analysis/11.Classifiers/03.Performance/AUC_barcodes.pdf", p1p, width = 9, height = 6)

p2s <- lapply(seq_along(Performs), function(j) {
  micro_pr_plot(Performs[[j]], title = panel_titles[j])
})
p2p <- cowplot::plot_grid(plotlist = p2s, nrow = 2)
ggsave("./analysis/11.Classifiers/03.Performance/PR_barcodes.pdf", p2p, width = 9, height = 6)

p3s <- lapply(seq_along(Performs), function(j) {
  cutoff_plot(Performs[[j]], title = panel_titles[j])
})
p3p <- cowplot::plot_grid(plotlist = p3s, nrow = 2)
ggsave("./analysis/11.Classifiers/03.Performance/Cutoff_barcodes.pdf", p3p, width = 10.92, height = 6)
# AUC value(s) stored in the "Macro" rows of each classifier's PR table,
# labelled by barcode count.
PR_AUC <- mapply(function(x) x$PR[Group == "Macro", unique(AUC)], Performs)

# The labels are positional; `names<-` would silently pad with NA when the
# lengths differ, so a mismatch is rejected explicitly.
pr_auc_labels <- c(2, 4, 6, 8, 10, 24)
if (length(PR_AUC) != length(pr_auc_labels)) {
  stop("`PR_AUC` has ", length(PR_AUC), " entries but ",
       length(pr_auc_labels), " barcode labels are defined.", call. = FALSE)
}
names(PR_AUC) <- pr_auc_labels
PR_AUC


# tangchao7498/DecodeR documentation built on May 27, 2023, 7:21 p.m.